# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import textwrap
import unittest
from unittest.mock import Mock
from unittest.mock import patch

from absl.testing import parameterized
from google.protobuf import json_format
from kfp.client import auth
from kfp.client import client
from kfp.compiler import Compiler
from kfp.dsl import component
from kfp.dsl import pipeline
from kfp.pipeline_spec import pipeline_spec_pb2
import kfp_server_api
import yaml


class TestValidatePipelineName(parameterized.TestCase):
    """Exercises ``client.validate_pipeline_display_name`` on good and bad names."""

    # Lists (not tuples): absl treats a lone tuple argument as a single case.
    _ACCEPTED_NAMES = [
        'pipeline',
        'my-pipeline',
        'my-pipeline-1',
        '1pipeline',
        'pipeline1',
        'my_pipeline',
        "person's-pipeline",
        'my pipeline',
        'pipeline.yaml',
    ]
    _REJECTED_NAMES = ['', '   ', '\t']
    _EXPECTED_ERROR = (
        'Invalid pipeline name. Pipeline name cannot be empty or contain only whitespace.'
    )

    @parameterized.parameters(_ACCEPTED_NAMES)
    def test_valid(self, name: str):
        # Success means the validator returns without raising.
        client.validate_pipeline_display_name(name)

    @parameterized.parameters(_REJECTED_NAMES)
    def test_invalid(self, name: str):
        # Empty or whitespace-only names must be rejected with ValueError.
        with self.assertRaisesRegex(ValueError, self._EXPECTED_ERROR):
            client.validate_pipeline_display_name(name)


class TestOverrideCachingOptions(parameterized.TestCase):
    """Tests for ``client._override_caching_options``."""

    def test_override_caching_of_multiple_components(self):
        """Overriding with True enables caching on every task in the DAG.

        One task is compiled with caching enabled and the other with caching
        disabled. After the override, both must report ``enableCache``.
        """

        @component
        def hello_word(text: str) -> str:
            return text

        @component
        def to_lower(text: str) -> str:
            return text.lower()

        @pipeline(
            name='sample two-step pipeline',
            description='a minimal two-step pipeline')
        def pipeline_with_two_component(text: str = 'hi there'):
            component_1 = hello_word(text=text).set_caching_options(True)
            # Compiled with caching disabled so the override has something to flip.
            to_lower(text=component_1.output).set_caching_options(False)

        with tempfile.TemporaryDirectory() as tempdir:
            temp_filepath = os.path.join(tempdir, 'hello_world_pipeline.yaml')
            Compiler().compile(
                pipeline_func=pipeline_with_two_component,
                package_path=temp_filepath)
            with open(temp_filepath, 'r') as f:
                compiled = yaml.safe_load(f)

        # The rest works on in-memory objects; the file is no longer needed.
        pipeline_spec = json_format.ParseDict(compiled,
                                              pipeline_spec_pb2.PipelineSpec())
        client._override_caching_options(pipeline_spec, True)
        tasks = json_format.MessageToDict(pipeline_spec)['root']['dag']['tasks']
        self.assertTrue(tasks['hello-word']['cachingOptions']['enableCache'])
        self.assertTrue(tasks['to-lower']['cachingOptions']['enableCache'])


class TestExtractPipelineYAML(parameterized.TestCase):
    """Tests ``client._extract_pipeline_yaml`` on single- and multi-document files."""

    # Minimal pipeline spec document shared by both tests.
    _PIPELINE_SPEC_YAML = textwrap.dedent('''
        components:
          comp-foo:
            executorLabel: exec-foo
        deploymentSpec:
          executors:
            exec-foo:
              container:
                command:
                - sh
                - -c
                - cat /data/file.txt
                image: alpine
        pipelineInfo:
          name: my-pipeline
        root:
          dag:
            tasks:
              foo:
                componentRef:
                  name: comp-foo
                taskInfo:
                  name: foo
        schemaVersion: 2.1.0
        sdkVersion: kfp-2.0.0-beta.13
        ''')

    # Kubernetes platform document appended after a '---' separator.
    _PLATFORM_SPEC_YAML = textwrap.dedent('''\
        platforms:
          kubernetes:
            deploymentSpec:
              executors:
                exec-foo:
                  pvcMount:
                  - mountPath: /data
                    constant: my-pvc
        ''')

    def _extract_from_file(self, file_name, contents):
        """Writes ``contents`` to a temp file and returns the extractor's ``to_dict()``."""
        with tempfile.TemporaryDirectory() as tempdir:
            yaml_path = os.path.join(tempdir, file_name)
            with open(yaml_path, 'w') as f:
                f.write(contents)
            return client._extract_pipeline_yaml(yaml_path).to_dict()

    def test_extract_pipeline_yaml_single_doc(self):
        extracted = self._extract_from_file('single_doc_pipeline.yaml',
                                            self._PIPELINE_SPEC_YAML)
        self.assertEqual('my-pipeline', extracted['pipelineInfo']['name'])

    def test_extract_pipeline_yaml_multiple_docs(self):
        two_docs = self._PIPELINE_SPEC_YAML + '---\n' + self._PLATFORM_SPEC_YAML
        extracted = self._extract_from_file('multi_docs_pipeline.yaml',
                                            two_docs)

        spec_doc = extracted['pipeline_spec']
        self.assertEqual('my-pipeline', spec_doc['pipelineInfo']['name'])

        executor = extracted['platform_spec']['platforms']['kubernetes'][
            'deploymentSpec']['executors']['exec-foo']
        self.assertEqual('my-pvc', executor['pvcMount'][0]['constant'])


class TestClient(parameterized.TestCase):
    """Unit tests for ``client.Client`` with the KFP API layer mocked out.

    Mock-call assertions are placed after any ``assertRaises`` context. Inside
    that context, statements following the raising call never execute.
    """

    def setUp(self):
        self.client = client.Client(namespace='ns1')

    def test_wait_for_run_completion_invalid_token_should_raise_error(self):
        with patch.object(
                self.client._run_api,
                'get_run',
                side_effect=kfp_server_api.ApiException) as mock_get_run:
            with self.assertRaises(kfp_server_api.ApiException):
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
        # The first failure is surfaced directly, without retrying.
        mock_get_run.assert_called_once_with(run_id='foo')

    def test_wait_for_run_completion_expired_access_token(self):
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            # We need to iterate through multiple side effects in order to test this logic.
            mock_get_run.side_effect = [
                Mock(state='unknown state'),
                kfp_server_api.ApiException(status=401),
                Mock(state='succeeded'),
            ]

            with patch.object(self.client, '_refresh_api_client_token'
                             ) as mock_refresh_api_client_token:
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
                mock_get_run.assert_called_with(run_id='foo')
                mock_refresh_api_client_token.assert_called_once()

    def test_wait_for_run_completion_valid_token(self):
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            mock_get_run.return_value = Mock(state='succeeded')
            response = self.client.wait_for_run_completion(
                run_id='foo', timeout=1, sleep_duration=0)
            mock_get_run.assert_called_once_with(run_id='foo')
            self.assertIs(response, mock_get_run.return_value)

    def test_wait_for_run_completion_run_timeout_should_raise_error(self):
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            # 'unknown state' does not finish the run (see the expired-token
            # test), so the client keeps polling until the timeout elapses.
            mock_get_run.return_value = Mock(state='unknown state')
            with self.assertRaises(TimeoutError):
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
        # The run may be polled several times, so only the arguments are checked.
        mock_get_run.assert_called_with(run_id='foo')

    @patch('kfp.Client.get_experiment', side_effect=ValueError)
    def test_create_experiment_no_experiment_should_raise_error(
            self, mock_get_experiment):
        with self.assertRaises(ValueError):
            self.client.create_experiment(name='foo', namespace='ns1')
        mock_get_experiment.assert_called_once_with(
            experiment_name='foo', namespace='ns1')

    @patch('kfp.Client.get_experiment', return_value=Mock(id='foo'))
    @patch('kfp.Client._get_url_prefix', return_value='/pipeline')
    def test_create_experiment_existing_experiment(self, mock_get_url_prefix,
                                                   mock_get_experiment):
        self.client.create_experiment(name='foo')
        mock_get_experiment.assert_called_once_with(
            experiment_name='foo', namespace='ns1')
        mock_get_url_prefix.assert_called_once()

    @patch('kfp_server_api.V2beta1Experiment')
    @patch(
        'kfp.Client.get_experiment',
        side_effect=ValueError('No experiment is found with name'))
    @patch('kfp.Client._get_url_prefix', return_value='/pipeline')
    def test__create_experiment_name_not_found(self, mock_get_url_prefix,
                                               mock_get_experiment,
                                               mock_api_experiment):
        # experiment with the specified name is not found, so a new experiment
        # is created.
        with patch.object(
                self.client._experiment_api,
                'create_experiment',
                return_value=Mock(
                    experiment_id='foo')) as mock_create_experiment:
            self.client.create_experiment(name='foo')
            mock_get_experiment.assert_called_once_with(
                experiment_name='foo', namespace='ns1')
            mock_api_experiment.assert_called_once()
            mock_create_experiment.assert_called_once()
            mock_get_url_prefix.assert_called_once()

    def test_get_experiment_no_experiment_id_or_name_should_raise_error(self):
        with self.assertRaises(ValueError):
            self.client.get_experiment()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_does_not_exist_should_raise_error(
            self, mock_get_user_namespace):
        with patch.object(
                self.client._experiment_api,
                'list_experiments',
                return_value=Mock(experiments=None)) as mock_list_experiments:
            with self.assertRaises(ValueError):
                self.client.get_experiment(experiment_name='foo')
        mock_list_experiments.assert_called_once()
        mock_get_user_namespace.assert_called_once()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_multiple_experiments_with_name_should_raise_error(
            self, mock_get_user_namespace):
        with patch.object(
                self.client._experiment_api,
                'list_experiments',
                return_value=Mock(
                    experiments=['foo', 'foo'])) as mock_list_experiments:
            with self.assertRaises(ValueError):
                self.client.get_experiment(experiment_name='foo')
        mock_list_experiments.assert_called_once()
        mock_get_user_namespace.assert_called_once()

    def test_get_experiment_with_experiment_id(self):
        with patch.object(self.client._experiment_api,
                          'get_experiment') as mock_get_experiment:
            self.client.get_experiment(experiment_id='foo')
            mock_get_experiment.assert_called_once_with(experiment_id='foo')

    def test_get_experiment_with_experiment_name_and_namespace(self):
        with patch.object(self.client._experiment_api,
                          'list_experiments') as mock_list_experiments:
            self.client.get_experiment(experiment_name='foo', namespace='ns1')
            mock_list_experiments.assert_called_once()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_with_experiment_name_and_no_namespace(
            self, mock_get_user_namespace):
        with patch.object(self.client._experiment_api,
                          'list_experiments') as mock_list_experiments:
            self.client.get_experiment(experiment_name='foo')
            mock_list_experiments.assert_called_once()
            mock_get_user_namespace.assert_called_once()

    @patch('kfp_server_api.HealthzServiceApi.get_healthz')
    def test_get_kfp_healthz(self, mock_get_kfp_healthz):
        mock_get_kfp_healthz.return_value = json.dumps([{'foo': 'bar'}])
        response = self.client.get_kfp_healthz()
        mock_get_kfp_healthz.assert_called_once()
        self.assertEqual(response, mock_get_kfp_healthz.return_value)

    @patch(
        'kfp_server_api.HealthzServiceApi.get_healthz',
        side_effect=kfp_server_api.ApiException)
    def test_get_kfp_healthz_should_raise_error(self, mock_get_kfp_healthz):
        with self.assertRaises(TimeoutError):
            self.client.get_kfp_healthz(sleep_duration=0)
        mock_get_kfp_healthz.assert_called()

    def test_upload_pipeline_without_name(self):

        @component
        def return_bool(boolean: bool) -> bool:
            return boolean

        @pipeline(name='test-upload-without-name', description='description')
        def pipeline_test_upload_without_name(boolean: bool = True):
            return_bool(boolean=boolean)

        with patch.object(self.client._upload_api,
                          'upload_pipeline') as mock_upload_pipeline:
            with patch.object(auth, 'is_ipython', return_value=False):
                with tempfile.TemporaryDirectory() as tmp_path:
                    pipeline_test_path = os.path.join(tmp_path, 'test.yaml')
                    Compiler().compile(
                        pipeline_func=pipeline_test_upload_without_name,
                        package_path=pipeline_test_path)
                    self.client.upload_pipeline(
                        pipeline_package_path=pipeline_test_path,
                        description='description',
                        namespace='ns1')
                    mock_upload_pipeline.assert_called_once_with(
                        pipeline_test_path,
                        name='test-upload-without-name',
                        description='description',
                        namespace='ns1')

    @parameterized.parameters([
        'pipeline',
        'my-pipeline',
        'my-pipeline-1',
        '1pipeline',
        'pipeline1',
        'my_pipeline',
        "person's-pipeline",
        'my pipeline',
        'pipeline.yaml',
    ])
    def test_upload_pipeline_with_name(self, pipeline_name):
        with patch.object(self.client._upload_api,
                          'upload_pipeline') as mock_upload_pipeline:
            with patch.object(auth, 'is_ipython', return_value=False):
                self.client.upload_pipeline(
                    pipeline_package_path='fake.yaml',
                    pipeline_name=pipeline_name,
                    description='description',
                    namespace='ns1')
                mock_upload_pipeline.assert_called_once_with(
                    'fake.yaml',
                    name=pipeline_name,
                    description='description',
                    namespace='ns1')

    @parameterized.parameters([
        '',
        '   ',
        '\t',
    ])
    def test_upload_pipeline_with_name_invalid(self, pipeline_name):
        with patch.object(self.client._upload_api,
                          'upload_pipeline') as mock_upload_pipeline:
            with patch.object(auth, 'is_ipython', return_value=False):
                with self.assertRaisesRegex(
                        ValueError,
                        'Invalid pipeline name. Pipeline name cannot be empty or contain only whitespace.'
                ):
                    self.client.upload_pipeline(
                        pipeline_package_path='fake.yaml',
                        pipeline_name=pipeline_name,
                        description='description',
                        namespace='ns1')
        # Validation must reject the name before anything is sent to the API.
        mock_upload_pipeline.assert_not_called()


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main()
